{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "V4B-KoWi1HDP"
   },
   "source": [
    "# Text Classification Using Embeddings\n",
    "This notebook shows how to build a classifiers using Cohere's embeddings.\n",
    "<img src=\"https://github.com/cohere-ai/notebooks/raw/main/notebooks/images/simple-classifier-embeddings.png\"\n",
    "style=\"width:100%; max-width:600px\"\n",
    "alt=\"first we embed the text in the dataset, then we use that to train a classifier\"/>\n",
    "\n",
    "The example classification task here will be sentiment analysis of film reviews. We'll train a simple classifier to detect whether a film review is negative (class 0) or positive (class 1).\n",
    "\n",
    "We'll go through the following steps:\n",
    "\n",
    "1. Get the dataset\n",
    "2. Get the embeddings of the reviews (for both the training set and the test set).\n",
    "3. Train a classifier using the training set\n",
    "4. Evaluate the performance of the classifier on the testing set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XdM6oEyo9oUp"
   },
   "outputs": [],
   "source": [
    "# Let's first install Cohere's python SDK\n",
    "# TODO: upgrade to \"cohere>5\"",
"!pip install \"cohere<5\" scikit-learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you're running an older version of the SDK you'll want to upgrade it, like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install --upgrade cohere"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "d0DIiCJoe_-_"
   },
   "source": [
    "## 1. Get the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "bAVd49D6BjGK"
   },
   "outputs": [],
   "source": [
    "import cohere\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import pandas as pd\n",
    "pd.set_option('display.max_colwidth', None)\n",
    "\n",
    "# Get the SST2 training and test sets\n",
    "df = pd.read_csv('https://github.com/clairett/pytorch-sentiment-classification/raw/master/data/SST2/train.tsv', delimiter='\\t', header=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BYhG9GO3B_Gd",
    "outputId": "e16f74d7-f6b1-44a6-db0b-49f21a3862e6"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>a stirring , funny and finally transporting re imagining of beauty and the beast and 1930s horror films</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>apparently reassembled from the cutting room floor of any given daytime soap</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>they presume their audience wo n't sit still for a sociology lesson , however entertainingly presented , so they trot out the conventional science fiction elements of bug eyed monsters and futuristic women in skimpy clothes</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>this is a visually stunning rumination on love , memory , history and the war between art and commerce</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>jonathan parker 's bartleby should have been the be all end all of the modern office anomie films</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                                                                                                                                                 0  \\\n",
       "0                                                                                                                          a stirring , funny and finally transporting re imagining of beauty and the beast and 1930s horror films   \n",
       "1                                                                                                                                                     apparently reassembled from the cutting room floor of any given daytime soap   \n",
       "2  they presume their audience wo n't sit still for a sociology lesson , however entertainingly presented , so they trot out the conventional science fiction elements of bug eyed monsters and futuristic women in skimpy clothes   \n",
       "3                                                                                                                           this is a visually stunning rumination on love , memory , history and the war between art and commerce   \n",
       "4                                                                                                                                jonathan parker 's bartleby should have been the be all end all of the modern office anomie films   \n",
       "\n",
       "   1  \n",
       "0  1  \n",
       "1  0  \n",
       "2  0  \n",
       "3  1  \n",
       "4  1  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's glance at the dataset\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7KIMt6G7dlGC"
   },
   "source": [
    "We'll only use a subset of the training and testing datasets in this example. We'll only use 500 examples since this is a toy example. You'll want to increase the number to get better performance and evaluation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `train_test_split` method splits arrays or matrices into random train and test subsets."
   ]
  },
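  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea concrete, here is a minimal pure-Python sketch of a random split. It is an illustration only, not scikit-learn's implementation; `simple_split`, `toy_reviews`, and `toy_labels` are hypothetical names, and the data is made up.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def simple_split(items, labels, test_size=0.25, seed=0):\n",
    "    # Shuffle the indices with a fixed seed so the split is reproducible.\n",
    "    indices = list(range(len(items)))\n",
    "    random.Random(seed).shuffle(indices)\n",
    "    # The first n_test shuffled indices become the test set.\n",
    "    n_test = round(len(items) * test_size)\n",
    "    test_idx, train_idx = indices[:n_test], indices[n_test:]\n",
    "    return ([items[i] for i in train_idx], [items[i] for i in test_idx],\n",
    "            [labels[i] for i in train_idx], [labels[i] for i in test_idx])\n",
    "\n",
    "\n",
    "# Hypothetical toy data standing in for the sampled reviews and labels\n",
    "toy_reviews = [f'review {i}' for i in range(500)]\n",
    "toy_labels = [i % 2 for i in range(500)]\n",
    "x_train, x_test, y_train, y_test = simple_split(toy_reviews, toy_labels)\n",
    "print(len(x_train), len(x_test))\n",
    "```\n",
    "\n",
    "With 500 examples and `test_size=0.25`, this yields 375 training and 125 test examples, and each review stays paired with its label."
   ]
  },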
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the number of examples from the dataset\n",
    "num_examples = 500\n",
    "# Create a dataframe that\n",
    "df_sample = df.sample(num_examples)\n",
    "\n",
    "# Split into training and testing sets\n",
    "sentences_train, sentences_test, labels_train, labels_test = train_test_split(\n",
    "            list(df_sample[0]), list(df_sample[1]), test_size=0.25, random_state=0)\n",
    "\n",
    "# The embeddings endpoint can take up to 96 texts, so we'll have to truncate \n",
    "# sentences_train, sentences_test, labels_train, and labels_test. \n",
    "\n",
    "sentences_train = sentences_train[:95]\n",
    "sentences_test = sentences_test[:95]\n",
    "\n",
    "labels_train = labels_train[:95]\n",
    "labels_test = labels_test[:95]"
   ]
  },
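  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The truncation above discards most of the sample. A minimal sketch of an alternative, assuming the 96-text limit mentioned in the code comment, is to split each list into batches and embed them one batch at a time. `batched` and `example_sentences` are hypothetical names used only for illustration, and no API call is made here.\n",
    "\n",
    "```python\n",
    "def batched(items, batch_size=96):\n",
    "    # Yield consecutive slices of at most batch_size items.\n",
    "    for start in range(0, len(items), batch_size):\n",
    "        yield items[start:start + batch_size]\n",
    "\n",
    "\n",
    "# Hypothetical stand-in for the 375 training sentences left after the split\n",
    "example_sentences = [f'review {i}' for i in range(375)]\n",
    "batch_sizes = [len(batch) for batch in batched(example_sentences)]\n",
    "print(batch_sizes)  # [96, 96, 96, 87]\n",
    "```\n",
    "\n",
    "Each batch could then be passed to the embed call in turn and the returned lists of embeddings joined together, so that no labelled example is thrown away."
   ]
  },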
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aymSE9YFEymy"
   },
   "source": [
    "## 2. Set up the Cohere client and get the embeddings of the reviews\n",
    "We're now ready to retrieve the embeddings from the API. You'll need your API key for this next cell. [Sign up to Cohere](https://os.cohere.ai/) and get one if you haven't yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "qqfWFS1RDE54"
   },
   "outputs": [],
   "source": [
    "# Add the model name, API key, URL, etc.\n",
    "model_name = \"embed-english-v3.0\"\n",
    "api_key = \"\"\n",
    "\n",
    "# Here, we're setting up the data objects we'll pass to the embeds endpoint.\n",
    "input_type = \"classification\"\n",
    "\n",
    "# Create and retrieve a Cohere API key from dashboard.cohere.ai\n",
    "co = cohere.Client(api_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the training set\n",
    "embeddings_train = co.embed(texts=sentences_train,\n",
    "                            model=model_name,\n",
    "                            input_type=input_type\n",
    "                            ).embeddings\n",
    "\n",
    "# Embed the testing set\n",
    "embeddings_test = co.embed(texts=sentences_test,\n",
    "                           model=model_name,\n",
    "                           input_type=input_type\n",
    "                            ).embeddings\n",
    "\n",
    "# Here we are using the endpoint co.embed() "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the ordering of the arguments is important. If you put `input_type` in before `model_name`, you'll get an error."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rhg6HkuXeFF_"
   },
   "source": [
    "We now have two sets of embeddings, `embeddings_train` contains the embeddings of the training  sentences while `embeddings_test` contains the embeddings of the testing sentences.\n",
    "\n",
    "Curious what an embedding looks like? we can print it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "AFWgw-l7eXRG",
    "outputId": "f958e3ff-f6b0-457b-d0e9-9acad20a7e4e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Review text: the script was reportedly rewritten a dozen times either 11 times too many or else too few\n",
      "Embedding vector: [1.1531117, -0.8543223, -1.2496399, -0.28317127, -0.75870246, 0.5373464, 0.63233083, 0.5766576, 1.8336298, 0.44203663]\n"
     ]
    }
   ],
   "source": [
    "print(f\"Review text: {sentences_train[0]}\")\n",
    "print(f\"Embedding vector: {embeddings_train[0][:10]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gTjZkSXEeoMB"
   },
   "source": [
    "## 3. Train a classifier using the training set\n",
    "Now that we have the embedding we can train our classifier. We'll use an SVM from sklearn."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1FGyYoHSCK03",
    "outputId": "d9e09ecb-e569-47a3-8c66-41b076ea5d42"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pipeline(steps=[('standardscaler', StandardScaler()),\n",
       "                ('svc', SVC(class_weight='balanced'))])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# import SVM classifier code\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "# Initialize a support vector machine, with class_weight='balanced' because \n",
    "# our training set has roughly an equal amount of positive and negative \n",
    "# sentiment sentences\n",
    "svm_classifier = make_pipeline(StandardScaler(), SVC(class_weight='balanced')) \n",
    "\n",
    "# fit the support vector machine\n",
    "svm_classifier.fit(embeddings_train, labels_train)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VuAjDXjzfWJ7"
   },
   "source": [
    "## 4. Evaluate the performance of the classifier on the testing set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NrHAN510fWlm",
    "outputId": "3036bf44-9b71-4698-859a-1f55f0ecc282"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation accuracy on Large is 91.2%!\n"
     ]
    }
   ],
   "source": [
    "# get the score from the test set, and print it out to screen!\n",
    "score = svm_classifier.score(embeddings_test, labels_test)\n",
    "print(f\"Validation accuracy on is {100*score}%!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You may get a slightly different number when you run this code."
   ]
  },
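  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a classifier, `score` reports mean accuracy: the fraction of test examples whose predicted label matches the true label. As a rough sketch of that calculation, the pure-Python snippet below computes it by hand. `accuracy`, `toy_true`, and `toy_pred` are hypothetical names, and the labels are made up for illustration rather than taken from the run above.\n",
    "\n",
    "```python\n",
    "def accuracy(true_labels, predicted_labels):\n",
    "    # Fraction of positions where the prediction matches the true label.\n",
    "    matches = sum(t == p for t, p in zip(true_labels, predicted_labels))\n",
    "    return matches / len(true_labels)\n",
    "\n",
    "\n",
    "# Hypothetical labels, made up purely for illustration\n",
    "toy_true = [1, 0, 0, 1, 1]\n",
    "toy_pred = [1, 0, 1, 1, 1]\n",
    "toy_score = accuracy(toy_true, toy_pred)\n",
    "print(f'Accuracy: {100 * toy_score:.1f}%')  # Accuracy: 80.0%\n",
    "```\n",
    "\n",
    "With a test set of at most 95 reviews, each misclassified review moves the accuracy by roughly one percentage point, which is one reason a larger test set gives a steadier estimate."
   ]
  },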
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0WnfBgA-OkKL",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "This was a small scale example, meant as a proof of concept and designed to illustrate how you can build a custom classifier quickly using a small amount of labelled data and Cohere's embeddings. Increase the number of training examples to achieve better performance on this task."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Text Classification Using Embeddings.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
